# 导入 BGE-M3 嵌入函数，用于生成文档和查询的向量表示
import torch.cuda
from milvus_model.hybrid import BGEM3EmbeddingFunction
# 导入 Milvus 相关类，用于操作向量数据库
from pymilvus import MilvusClient, DataType, AnnSearchRequest, WeightedRanker
# 导入 Document 类，用于创建文档对象
from langchain.docstore.document import Document
# 导入 CrossEncoder，用于重排序和 NLI 判断
from sentence_transformers import CrossEncoder
# 导入 hashlib 模块，用于生成唯一 ID 的哈希值
import hashlib
from rag_qa.core.document_loader import *
from base import logger, Config
import sys
import os

# 获取当前文件所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取 core 文件所在的目录的绝对路径
rag_qa_path = os.path.dirname(current_dir)
sys.path.insert(0,  rag_qa_path)
# 获取根目录文件所在的绝对位置
project_root = os.path.dirname(rag_qa_path)
sys.path.insert(0, project_root)

# # 查看路径
# print(f'current_dir --> {current_dir}')
# print(f'rag_qa_path --> {rag_qa_path}')
# print(f'project_root --> {project_root}')


config = Config()

# 定义 VectorStore 类，封装向量存储和检索功能
class VectorStore:
	# 初始化方法，设置向量存储的基本参数
	def __init__(self,
				 collection_name=config.MILVUS_COLLECTION_NAME,
				 host=config.MILVUS_HOST,
				 port=config.MILVUS_PROT,
				 database=config.MILVUS_DATABASE_NAME):
		# 设置 Milvus 集合名称
		self.collection_name = collection_name
		# 设置 Milvus 主机地址
		self.host = host
		# 设置 Milvus 端口号
		self.port = port
		# 设置 Milvus 数据库名称
		self.database = database
		# 设置日志记录器
		self.logger = logger
		# 将 device 设置为 cpu
		self.device = 'cpu'
		# 日志提醒使用的是什么设备
		self.logger.info(f"使用设备：{self.device}")
		# 初始化 BGE-Reranker 模型，用于重排序
		reranker_path = os.path.join(rag_qa_path, 'models', 'bge-reranker-large')
		# 初始化 BGE-M3 嵌入函数，使用 CPU 设备，不启用FP16
		m3_path = os.path.join(rag_qa_path, 'models', 'bge-m3')
		self.embedding_function = BGEM3EmbeddingFunction(
			model_name_or_path=m3_path,
			use_fp16=(self.device == 'cuda'),
			device=self.device
		)

		# 获取稠密向量的维度 1024
		self.dense_dim = self.embedding_function.dim['dense']
		# 初始化 Milvus 客户端，连接到指定主机和服务器
		self.client = MilvusClient(url=f'http://{self.host}:{self.port}', db_name=self.database)
		# 调用方法创建或加载 Milvus 集合
		self._create_or_load_collection()
	
	# 私有类方法
	def _create_or_load_collection(self):
		# 检查指定集合是否已经存在
		if not self.client.has_collection(self.collection_name):
			# 创建集合 Schema，禁用自动 ID，启用动态字段
			schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True)
			# 添加 ID 字段，作为主键，VARCHAR 类型，最大长度 100
			schema.add_field(field_name='id', datatype=DataType.VARCHAR, is_primary=True, max_length=100)
			# 添加文本字段，VARCHAR 类型，最大长度 65535
			schema.add_field(field_name='text', datatype=DataType.VARCHAR, max_length=65535)
			# 添加稠密向量字段，FLOAT_VECTOR 类型，维度由嵌入函数指定
			schema.add_field(field_name='dense_vector', datatype=DataType.FLOAT_VECTOR, dim=self.dense_dim)
			# 添加稀疏向量字段，SPARSE_FLOAT_VECTOR类型
			schema.add_field(field_name='sparse_vector', datatype=DataType.SPARSE_FLOAT_VECTOR)
			# 添加父块 ID 字段，VARCHAR 类型，最大长度 100
			schema.add_field(field_name='parent_id', datatype=DataType.VARCHAR, max_length=100)
			# 添加父块内容字段，VARCHAR 类型，最大长度 65535
			schema.add_field(field_name='parent_content', datatype=DataType.VARCHAR, max_length=65535)
			# 添加学科类别字段，VARCHAR 类型，最大长度 50
			schema.add_field(field_name='source', datatype=DataType.VARCHAR, max_length=50)
			# 添加时间戳字段，VARCHAR 类型，最大长度 50
			schema.add_field(field_name='timestamp', datatype=DataType.VARCHAR, max_length=50)

	# 定义方法，向向量存储添加文档
	def add_documents(self, documents):
		# 提取所有文档的内容列表
		texts = [doc.page_content for doc in documents]

		# 使用 BGE-M3 嵌入函数生成文档的嵌入
		embeddings = self.embedding_function(texts)
		# 初始化空列表，存储插入的数据
		data = []
		# 遍历每个文档，带上索引 i
		for i, doc in enumerate(documents):
			# 生成文档内容的哈希值作为唯一 ID
			text_hash = hashlib.md5(doc.page_content.encode('utf-8').hexdigest())
			# 初始化一个稀疏向量的字典（Milvus 要求存储稀疏向量的格式）
			sparse_vector = {}

			try:
				row = embeddings['sparse'][i]
				if hasattr(row, 'col'):
					indices = row.col
					values = row.data
				else:
					indices = row.indices
					values = row.values
			except Exception as e:
				row = embeddings['sparse'].getrow(i)
				indices = row.indices
				values = row.data

			# 将索引和值进行配对，存储到字典中
			for idx, value in zip(indices, values):
				sparse_vector[idx] = value
			# 创建数据字典，包含所有字段
			data.append({
				'id': text_hash,
				'text': doc.page_content,
				'dense_vector': embeddings['dense'][i],
				'sparse_vector': sparse_vector,
				'parent_id': doc.metadata['parent_id'],
				'parent_content': doc.metadata['parent_content'],
				'source': doc.metadata.get('source', 'unknown'),
				'timestamp': doc.metadata.get('timestamp', 'unknown')
			})
		# 检查是否有数据需要插入
		if data:
			# 使用 upsert操作插入数据，覆盖重复 ID
			self.client.upsert(collection_name=self.collection_name, data=data)
			# 记录插入或更新的文档数量日志
			logger.info(f"已插入或更新 {len(data)} 个文档")

	def hybrid_search_with_rerank(self, query, k=config.RETRIEVAL_K, source_filter=None):
		# 使用 BGE-M3 嵌入函数生成查询的嵌入
		query_embeddings = self.embedding_function([query])
		# 获取查询的稠密向量
		dense_query_vector = query_embeddings['dense'][0]
		# 初始化查询的稀疏向量字典
		sparse_query_vector = {}
		try:
			row = query_embeddings['sparse'][0]
			if hasattr(row, 'col'):
				indices = row.col
				values = row.data
			else:
				indices = row.indices
				values = row.values
		except Exception as e:
			row = query_embeddings['sparse'].getrow(0)
			indices = row.indices
			values = row.data

		# 将索引和值配对，填充稀疏向量字典
		for idx, value in zip(indices, values):
			sparse_query_vector[idx] = value
			# 初始化过虑表达式，默认不过滤
			filter_expr = f"source == '{source_filter}'" if source_filter else ""
			# 创建稠密向量搜索请求
			dense_request = AnnSearchRequest(
				data=[dense_query_vector],
				anns_field='dense_vector',
				param={"metric_type": "IP", "params": {"nprobe": 10}},
				limit=k,
				expr=filter_expr
			)
			# 创建稀疏向量搜索请求
			sqarse_request = AnnSearchRequest(
				data=[sparse_query_vector],
				anns_field="sparse_vector",
				param={"metric_type": "IP", "params": {}},
				limit=k,
				expr=filter_expr
			)
			# 创建加权排序器，稀疏向量权重 0.7，稠密向量权重 1.0
			ranker = WeightedRanker(1.0, 0.7)
			# 执行混合搜索，返回 Top-K 结果
			results = self.client.hybrid_search(
				collection_name=self.collection_name,
				reqs=[dense_request, sqarse_request],
				ranker=ranker,
				limit=k,
				output_fields=['text', 'parent_id', 'parent_content', 'source', 'timestamp']
			)[0]
			# 将上述搜索结果进行 Document 对象封装，便于查询使用
			sub_chunks = [self._doc_from_hit(hit['entity']) for hit in results]
			# 从子块中提取去重的父文档
			parent_docs = self._get_unique_parent_docs(sub_chunks)
			# 如果只有 1 个文档或者没有，直接返回跳过重排序
			if len(parent_docs) < 2:
				return parent_docs[: config.CANDIDATE_M]
			# 如果没有父文档，返回空列表
			else:
				ranked_parent_docs = []
			# 返回前 m 个重排序后的文档
			return ranked_parent_docs[: config.CANDIDATE_M]

	def _get_unique_parent_docs(self, sub_chunks):
		# 初始化集合，用于存储已处理的父块内容（去重）
		parent_contents = set()
		# 初始化列表，用于存储唯一父文档
		unique_docs = []
		# 遍历所有子块
		for chunk in sub_chunks:
			#获取子块的父块内容，默认为子块内容
			parent_content = chunk.metadata.get('parent_content', chunk.page_content)
			# 检查父块内容是否非空且未重复
			if parent_content and parent_content not in parent_contents:
				# 创建新的 Document 对象，包含父块内容和元数据
				unique_docs.append(Document(page_content=parent_content, metadata=chunk.metadata))
				# 将父块内容添加到去重集合
				parent_contents.add(parent_content)
		# 返回去重后的父文档列表
		return unique_docs

	# 定义类似有方法，从 Milvus 查询结果创建Document 对象
	def _doc_from_hit(self, hit):
		# 创建并返回 Document 对象，填充内容和元数据
		return Document(
			page_content=hit.get('text'),
			metadata={
				'parent_id': hit.get('parent_id'),
				'parent_content': hit.get('parent_content'),
				'source': hit.get('source'),
				'timestamp': hit.get('timestamp')
			}
		)


if __name__ == "__main__":
	vector_store = VectorStore()
	# directory_path = '/Users/chan/projects/Itcast_qa_system/rag_qa/data/ai_data'
	# print(f"embedding_function.dim--》{vector_store.embedding_function.dim}")
	# documents = process_documents(directory_path)
	# vector_store.add_documents(documents)

	query = "AI学科学费是多少？"
	results = vector_store.hybrid_search_with_rerank(query, source_filter='ai')
	print(f'results-->{results}')
	print(f'results-->{len(results)}')
